{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Copyright (c) Microsoft Corporation. All rights reserved.*\n",
    "\n",
    "*Licensed under the MIT License.*\n",
    "\n",
    "# Text Classification of MultiNLI Sentences using Multiple Transformer Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import sys\n",
    "from tempfile import TemporaryDirectory\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scrapbook as sb\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from sklearn.metrics import accuracy_score, classification_report\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from tqdm import tqdm\n",
    "from utils_nlp.common.timer import Timer\n",
    "from utils_nlp.common.pytorch_utils import dataloader_from_dataset\n",
    "from utils_nlp.dataset.multinli import load_pandas_df\n",
    "from utils_nlp.models.transformers.sequence_classification import (\n",
    "    Processor, SequenceClassifier)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "In this notebook, we fine-tune and evaluate a number of pretrained models on a subset of the [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) dataset.\n",
    "\n",
    "We use a [sequence classifier](../../utils_nlp/models/transformers/sequence_classification.py) that wraps [Hugging Face's PyTorch implementation](https://github.com/huggingface/transformers) of different transformers, like [BERT](https://github.com/google-research/bert), [XLNet](https://github.com/zihangdai/xlnet), and [RoBERTa](https://github.com/pytorch/fairseq)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# notebook parameters\n",
    "DATA_FOLDER = TemporaryDirectory().name\n",
    "CACHE_DIR = TemporaryDirectory().name\n",
    "NUM_EPOCHS = 1\n",
    "BATCH_SIZE = 16\n",
    "NUM_GPUS = 2\n",
    "MAX_LEN = 100\n",
    "TRAIN_DATA_FRACTION = 0.05\n",
    "TEST_DATA_FRACTION = 0.05\n",
    "TRAIN_SIZE = 0.75\n",
    "LABEL_COL = \"genre\"\n",
    "TEXT_COL = \"sentence1\"\n",
    "MODEL_NAMES = [\"distilbert-base-uncased\", \"roberta-base\", \"xlnet-base-cased\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read Dataset\n",
    "We start by loading a subset of the data. The loading function used below also downloads and extracts the files if they are not already in the data folder.\n",
    "\n",
    "The MultiNLI dataset is mainly used for natural language inference (NLI) tasks, where the inputs are sentence pairs and the labels are entailment indicators. The sentence pairs are also classified into *genres* that allow for more coverage and better evaluation of NLI models.\n",
    "\n",
    "For our classification task, we use only the first sentence of each pair as the text input and the corresponding genre as the label. The sentence pairs are unique, but the first sentences are not, so we keep only the examples that carry one of the entailment labels (*neutral* in this case) to avoid duplicate rows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 222k/222k [01:20<00:00, 2.74kKB/s] \n"
     ]
    }
   ],
   "source": [
    "df = load_pandas_df(DATA_FOLDER, \"train\")\n",
    "df = df[df[\"gold_label\"] == \"neutral\"]  # get unique sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>genre</th>\n",
       "      <th>sentence1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>government</td>\n",
       "      <td>Conceptually cream skimming has two basic dime...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>telephone</td>\n",
       "      <td>yeah i tell you what though if you go price so...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>travel</td>\n",
       "      <td>But a few Christian mosaics survive above the ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>slate</td>\n",
       "      <td>It's not that the questions they asked weren't...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>travel</td>\n",
       "      <td>Thebes held onto power until the 12th Dynasty,...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         genre                                          sentence1\n",
       "0   government  Conceptually cream skimming has two basic dime...\n",
       "4    telephone  yeah i tell you what though if you go price so...\n",
       "6       travel  But a few Christian mosaics survive above the ...\n",
       "12       slate  It's not that the questions they asked weren't...\n",
       "13      travel  Thebes held onto power until the 12th Dynasty,..."
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df[[LABEL_COL, TEXT_COL]].head()"
   ]
  },
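  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of the label filter can be checked without downloading the data. The following dependency-free sketch uses a few made-up rows (the sentences and the `rows` variable are hypothetical, not taken from MultiNLI) and assumes that each first sentence is paired with one hypothesis per entailment label. Keeping a single label then leaves one row per first sentence.\n",
    "\n",
    "```python\n",
    "# Toy rows (hypothetical): every first sentence occurs once per entailment label.\n",
    "labels = ['entailment', 'neutral', 'contradiction']\n",
    "sentences = [\n",
    "    ('The museum opens at nine.', 'travel'),\n",
    "    ('yeah i guess so', 'telephone'),\n",
    "]\n",
    "rows = [\n",
    "    {'sentence1': text, 'genre': genre, 'gold_label': label}\n",
    "    for text, genre in sentences\n",
    "    for label in labels\n",
    "]\n",
    "\n",
    "# Keep a single entailment label so that each first sentence appears once.\n",
    "neutral_rows = [r for r in rows if r['gold_label'] == 'neutral']\n",
    "premises = [r['sentence1'] for r in neutral_rows]\n",
    "\n",
    "print(len(rows), len(neutral_rows))\n",
    "print(len(premises) == len(set(premises)))\n",
    "```\n",
    "\n",
    "Any one of the three labels would serve the same purpose under this assumption; the notebook uses *neutral*."
   ]
  },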
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We split the data for training and testing, sample a fraction for faster execution, and encode the class labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/bleik2/backup/.conda/envs/nlp_gpu/lib/python3.6/site-packages/sklearn/model_selection/_split.py:2179: FutureWarning: From version 0.21, test_size will always complement train_size unless both are specified.\n",
      "  FutureWarning)\n"
     ]
    }
   ],
   "source": [
    "# split\n",
    "df_train, df_test = train_test_split(df, train_size=TRAIN_SIZE, random_state=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample\n",
    "df_train = df_train.sample(frac=TRAIN_DATA_FRACTION).reset_index(drop=True)\n",
    "df_test = df_test.sample(frac=TEST_DATA_FRACTION).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The examples in the dataset are grouped into 5 genres:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "telephone     1043\n",
       "slate          989\n",
       "fiction        968\n",
       "travel         964\n",
       "government     945\n",
       "Name: genre, dtype: int64"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train[LABEL_COL].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# encode labels\n",
    "label_encoder = LabelEncoder()\n",
    "df_train[LABEL_COL] = label_encoder.fit_transform(df_train[LABEL_COL])\n",
    "df_test[LABEL_COL] = label_encoder.transform(df_test[LABEL_COL])\n",
    "\n",
    "num_labels = len(np.unique(df_train[LABEL_COL]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of unique labels: 5\n",
      "Number of training examples: 4909\n",
      "Number of testing examples: 1636\n"
     ]
    }
   ],
   "source": [
    "print(\"Number of unique labels: {}\".format(num_labels))\n",
    "print(\"Number of training examples: {}\".format(df_train.shape[0]))\n",
    "print(\"Number of testing examples: {}\".format(df_test.shape[0]))"
   ]
  },
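  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the encoding step concrete, the following dependency-free sketch mimics what `LabelEncoder` does: it maps each distinct label to its position in the sorted list of classes. The helper names `fit_classes`, `encode`, and `decode` are made up for illustration and are not part of scikit-learn.\n",
    "\n",
    "```python\n",
    "def fit_classes(labels):\n",
    "    # Sorted distinct labels, analogous to LabelEncoder.classes_\n",
    "    return sorted(set(labels))\n",
    "\n",
    "\n",
    "def encode(labels, classes):\n",
    "    # Map every label to the index of its class\n",
    "    index = {c: i for i, c in enumerate(classes)}\n",
    "    return [index[label] for label in labels]\n",
    "\n",
    "\n",
    "def decode(codes, classes):\n",
    "    # Inverse mapping, analogous to LabelEncoder.inverse_transform\n",
    "    return [classes[i] for i in codes]\n",
    "\n",
    "\n",
    "genres = ['travel', 'fiction', 'telephone', 'fiction', 'slate', 'government']\n",
    "classes = fit_classes(genres)\n",
    "codes = encode(genres, classes)\n",
    "\n",
    "print(classes)\n",
    "print(codes)\n",
    "print(decode(codes, classes) == genres)\n",
    "```\n",
    "\n",
    "Fitting on the training labels only and reusing the same classes for the test labels, as done above, keeps the integer codes consistent across both splits."
   ]
  },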
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Select Pretrained Models\n",
    "\n",
    "[Hugging Face](https://github.com/huggingface/transformers) provides a number of pretrained models. Our sequence classifier supports the following ones for text classification:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model_name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>bert-base-uncased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>bert-large-uncased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>bert-base-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>bert-large-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>bert-base-multilingual-uncased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>bert-base-multilingual-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>bert-base-chinese</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>bert-base-german-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>bert-large-uncased-whole-word-masking</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>bert-large-cased-whole-word-masking</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>bert-large-uncased-whole-word-masking-finetune...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>bert-large-cased-whole-word-masking-finetuned-...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>bert-base-cased-finetuned-mrpc</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>bert-base-german-dbmdz-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>bert-base-german-dbmdz-uncased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>bert-base-japanese</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>bert-base-japanese-whole-word-masking</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>bert-base-japanese-char</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>bert-base-japanese-char-whole-word-masking</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>bert-base-finnish-cased-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>bert-base-finnish-uncased-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>roberta-base</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>roberta-large</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>roberta-large-mnli</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>distilroberta-base</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>roberta-base-openai-detector</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>roberta-large-openai-detector</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>xlnet-base-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>xlnet-large-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>distilbert-base-uncased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>distilbert-base-uncased-distilled-squad</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>distilbert-base-german-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>distilbert-base-multilingual-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>albert-base-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>albert-large-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>albert-xlarge-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>albert-xxlarge-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>albert-base-v2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>albert-large-v2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>albert-xlarge-v2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40</th>\n",
       "      <td>albert-xxlarge-v2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           model_name\n",
       "0                                   bert-base-uncased\n",
       "1                                  bert-large-uncased\n",
       "2                                     bert-base-cased\n",
       "3                                    bert-large-cased\n",
       "4                      bert-base-multilingual-uncased\n",
       "5                        bert-base-multilingual-cased\n",
       "6                                   bert-base-chinese\n",
       "7                              bert-base-german-cased\n",
       "8               bert-large-uncased-whole-word-masking\n",
       "9                 bert-large-cased-whole-word-masking\n",
       "10  bert-large-uncased-whole-word-masking-finetune...\n",
       "11  bert-large-cased-whole-word-masking-finetuned-...\n",
       "12                     bert-base-cased-finetuned-mrpc\n",
       "13                       bert-base-german-dbmdz-cased\n",
       "14                     bert-base-german-dbmdz-uncased\n",
       "15                                 bert-base-japanese\n",
       "16              bert-base-japanese-whole-word-masking\n",
       "17                            bert-base-japanese-char\n",
       "18         bert-base-japanese-char-whole-word-masking\n",
       "19                         bert-base-finnish-cased-v1\n",
       "20                       bert-base-finnish-uncased-v1\n",
       "21                                       roberta-base\n",
       "22                                      roberta-large\n",
       "23                                 roberta-large-mnli\n",
       "24                                 distilroberta-base\n",
       "25                       roberta-base-openai-detector\n",
       "26                      roberta-large-openai-detector\n",
       "27                                   xlnet-base-cased\n",
       "28                                  xlnet-large-cased\n",
       "29                            distilbert-base-uncased\n",
       "30            distilbert-base-uncased-distilled-squad\n",
       "31                       distilbert-base-german-cased\n",
       "32                 distilbert-base-multilingual-cased\n",
       "33                                     albert-base-v1\n",
       "34                                    albert-large-v1\n",
       "35                                   albert-xlarge-v1\n",
       "36                                  albert-xxlarge-v1\n",
       "37                                     albert-base-v2\n",
       "38                                    albert-large-v2\n",
       "39                                   albert-xlarge-v2\n",
       "40                                  albert-xxlarge-v2"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame({\"model_name\": SequenceClassifier.list_supported_models()})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-tune\n",
    "\n",
    "Our wrappers let us fine-tune different models through a single interface, hiding the model-specific preprocessing that is needed before training. In this example, we select the following models and use the same code to fine-tune each of them on our genre classification task. Note that some of the supported models were pretrained on multilingual data and can be used with non-English datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['distilbert-base-uncased', 'roberta-base', 'xlnet-base-cased']\n"
     ]
    }
   ],
   "source": [
    "print(MODEL_NAMES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For each pretrained model, we preprocess the data, fine-tune the classifier, score the test set, and store the evaluation results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/bleik2/backup/.conda/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    }
   ],
   "source": [
    "results = {}\n",
    "\n",
    "for model_name in tqdm(MODEL_NAMES, disable=True):\n",
    "\n",
    "    # preprocess\n",
    "    processor = Processor(\n",
    "        model_name=model_name,\n",
    "        to_lower=model_name.endswith(\"uncased\"),\n",
    "        cache_dir=CACHE_DIR,\n",
    "    )\n",
    "    train_dataset = processor.dataset_from_dataframe(\n",
    "        df_train, TEXT_COL, LABEL_COL, max_len=MAX_LEN\n",
    "    )\n",
    "    train_dataloader = dataloader_from_dataset(\n",
    "        train_dataset, batch_size=BATCH_SIZE, num_gpus=NUM_GPUS, shuffle=True\n",
    "    )\n",
    "    test_dataset = processor.dataset_from_dataframe(\n",
    "        df_test, TEXT_COL, LABEL_COL, max_len=MAX_LEN\n",
    "    )\n",
    "    test_dataloader = dataloader_from_dataset(\n",
    "        test_dataset, batch_size=BATCH_SIZE, num_gpus=NUM_GPUS, shuffle=False\n",
    "    )\n",
    "\n",
    "    # fine-tune\n",
    "    classifier = SequenceClassifier(\n",
    "        model_name=model_name, num_labels=num_labels, cache_dir=CACHE_DIR\n",
    "    )\n",
    "    with Timer() as t:\n",
    "        classifier.fit(\n",
    "            train_dataloader, num_epochs=NUM_EPOCHS, num_gpus=NUM_GPUS, verbose=False,\n",
    "        )\n",
    "    train_time = t.interval / 3600\n",
    "\n",
    "    # predict\n",
    "    preds = classifier.predict(test_dataloader, num_gpus=NUM_GPUS, verbose=False)\n",
    "\n",
    "    # eval\n",
    "    accuracy = accuracy_score(df_test[LABEL_COL], preds)\n",
    "    class_report = classification_report(\n",
    "        df_test[LABEL_COL], preds, target_names=label_encoder.classes_, output_dict=True\n",
    "    )\n",
    "\n",
    "    # save results\n",
    "    results[model_name] = {\n",
    "        \"accuracy\": accuracy,\n",
    "        \"f1-score\": class_report[\"macro avg\"][\"f1-score\"],\n",
    "        \"time(hrs)\": train_time,\n",
    "    }"
   ]
  },
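  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `to_lower` argument in the loop above is derived from the model name. As a quick check of that heuristic, this sketch evaluates it for the three models used in this notebook; `needs_lowercasing` is a hypothetical helper name introduced only here.\n",
    "\n",
    "```python\n",
    "MODEL_NAMES = ['distilbert-base-uncased', 'roberta-base', 'xlnet-base-cased']\n",
    "\n",
    "\n",
    "def needs_lowercasing(model_name):\n",
    "    # Same test as in the fine-tuning loop: only names ending in 'uncased' qualify.\n",
    "    return model_name.endswith('uncased')\n",
    "\n",
    "\n",
    "for name in MODEL_NAMES:\n",
    "    print(name, needs_lowercasing(name))\n",
    "```\n",
    "\n",
    "The rule only matches names that end in `uncased`. A name such as `bert-large-uncased-whole-word-masking` does not match, so `to_lower` may need to be set explicitly for such models."
   ]
  },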
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate\n",
    "\n",
    "Finally, we report the accuracy and macro-averaged F1 score of each model, as well as the fine-tuning time in hours."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>distilbert-base-uncased</th>\n",
       "      <th>roberta-base</th>\n",
       "      <th>xlnet-base-cased</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>accuracy</th>\n",
       "      <td>0.889364</td>\n",
       "      <td>0.885697</td>\n",
       "      <td>0.886308</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>f1-score</th>\n",
       "      <td>0.885225</td>\n",
       "      <td>0.880926</td>\n",
       "      <td>0.881819</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>time(hrs)</th>\n",
       "      <td>0.023326</td>\n",
       "      <td>0.044209</td>\n",
       "      <td>0.052801</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           distilbert-base-uncased  roberta-base  xlnet-base-cased\n",
       "accuracy                  0.889364      0.885697          0.886308\n",
       "f1-score                  0.885225      0.880926          0.881819\n",
       "time(hrs)                 0.023326      0.044209          0.052801"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_results = pd.DataFrame(results)\n",
    "df_results"
   ]
  },
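  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The F1 value in the table is taken from the `macro avg` entry of the classification report, that is, the unweighted mean of the per-class F1 scores. The sketch below computes that quantity by hand on a few made-up labels (`y_true` and `y_pred` are illustrative, not results from this notebook); `macro_f1` is a hypothetical helper name.\n",
    "\n",
    "```python\n",
    "def macro_f1(y_true, y_pred):\n",
    "    scores = []\n",
    "    for c in sorted(set(y_true) | set(y_pred)):\n",
    "        pairs = list(zip(y_true, y_pred))\n",
    "        tp = sum(1 for t, p in pairs if t == c and p == c)\n",
    "        fp = sum(1 for t, p in pairs if t != c and p == c)\n",
    "        fn = sum(1 for t, p in pairs if t == c and p != c)\n",
    "        # Per-class F1 = 2*TP / (2*TP + FP + FN); defined as 0 when the class never occurs\n",
    "        denom = 2 * tp + fp + fn\n",
    "        scores.append(2 * tp / denom if denom else 0.0)\n",
    "    # Macro average: every class counts equally, regardless of its size\n",
    "    return sum(scores) / len(scores)\n",
    "\n",
    "\n",
    "y_true = [0, 0, 1, 1, 2, 2]\n",
    "y_pred = [0, 1, 1, 1, 2, 0]\n",
    "print(round(macro_f1(y_true, y_pred), 4))\n",
    "```\n",
    "\n",
    "Because the genres in the training sample are roughly balanced, the macro-averaged F1 score and the accuracy are expected to be close, as the table above shows."
   ]
  },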
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/scrapbook.scrap.json+json": {
       "data": 0.887123064384678,
       "encoder": "json",
       "name": "accuracy",
       "version": 1
      }
     },
     "metadata": {
      "scrapbook": {
       "data": true,
       "display": false,
       "name": "accuracy"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/scrapbook.scrap.json+json": {
       "data": 0.8826569624491233,
       "encoder": "json",
       "name": "f1",
       "version": 1
      }
     },
     "metadata": {
      "scrapbook": {
       "data": true,
       "display": false,
       "name": "f1"
      }
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# for testing: record the mean accuracy and F1 score across models\n",
    "sb.glue(\"accuracy\", df_results.loc[\"accuracy\"].mean())\n",
    "sb.glue(\"f1\", df_results.loc[\"f1-score\"].mean())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6.8 64-bit ('nlp_gpu': conda)",
   "language": "python",
   "name": "python36864bitnlpgpucondaa579511bcea84c65877ff3dca4205921"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
